# This script works with run_simulation.slurm to execute a single simulation from a slurm array.
## Usage: Rscript run_simulation.R <simulation> <slurm_array_id>

# All simulations are defined in this file and selected by number via the <simulation> argument.
## Simulation 1: Mean-only estimation
## Simulation 2: Logistic regression with surrogate outcome
## Simulation 3: Linear regression with surrogate predictor
## Simulation 4: Logistic regression with continuous surrogate predictor

# GLOBAL VARIABLES
N_SIMS_PER_DESIGN <- 30 # Repetitions per design; passed to repeat_simulation() as n_sims in main()
DEV <- FALSE # Forwarded unchanged to load_data(dev=DEV) in main()

# Source simulation_utils
## NOTE(review): presumably this is where load_data(), repeat_simulation(),
## binary_error() and continuous_error() come from -- main() calls them but they
## are not defined in this file; confirm.
## The path is relative, so it is resolved against the current working directory.
source("simulation_utils.R")


# Design grid shared by simulations 1-3: every combination of n_label with the
# `acc` and `bal` values that are handed to binary_error().
.sim_binary_grid <- function(){
    expand.grid(
        n_label = c(100, 250, 500, 750, 1000, 2000),
        acc=c(0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99),
        bal=c(0.5, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99)
    )
}

# Build the error function for one row of the binary design grid.
# The row's values are copied out up front so the closure does not depend on
# the caller's `design` variable.
.sim_binary_error <- function(design){
    acc <- design$acc
    bal <- design$bal
    function(arr){binary_error(arr, acc, bal)}
}

# Build the error function for one row of the simulation-4 design grid: bias and
# sd passed to continuous_error() are the row's scales times mean(arr) / sd(arr).
.sim_continuous_error <- function(design){
    bias_scale <- design$bias_scale
    sd_scale <- design$sd_scale
    function(arr){continuous_error(arr, bias=mean(arr)*bias_scale, sd=sd(arr)*sd_scale)}
}

# Return the full specification of one simulation as a list:
#   fml, surrogate_var, estimator_name -- forwarded to repeat_simulation()
#   design_grid                        -- one row per slurm array task
#   make_error_func                    -- function(design_row) -> error function
#   out_prefix                         -- directory prefix of the result file
#   skip_existing                      -- TRUE if an existing result file must not be recomputed
# Stops with "Invalid simulation number" for anything other than 1, 2, 3 or 4.
.sim_spec <- function(simulation){
    if (length(simulation) != 1 || is.na(simulation) || !(simulation %in% 1:4)) {
        stop("Invalid simulation number")
    }
    switch(as.character(simulation),
        "1" = list( # Simulation 1: Mean-only estimation
            fml = formula("fire ~ 1"),
            surrogate_var = "fire",
            estimator_name = "lm", # This is a kludge, but we are doing proportions
            design_grid = .sim_binary_grid(),
            make_error_func = .sim_binary_error,
            out_prefix = "results/sim1/",
            skip_existing = TRUE
        ),
        "2" = list( # Simulation 2: Logistic regression with surrogate outcome
            fml = formula("police ~ violence"),
            surrogate_var = "police",
            estimator_name = "logit",
            design_grid = .sim_binary_grid(),
            make_error_func = .sim_binary_error,
            out_prefix = "results/sim2/",
            skip_existing = FALSE
        ),
        "3" = list( # Simulation 3: Linear regression with surrogate predictor
            fml = formula("violence ~ sign + photo + fire + police + children + group_20 + flag + night + shouting"),
            surrogate_var = "fire",
            estimator_name = "lm",
            design_grid = .sim_binary_grid(),
            make_error_func = .sim_binary_error,
            out_prefix = "results/sim3/",
            skip_existing = FALSE
        ),
        "4" = list( # Simulation 4: Logistic regression with continuous surrogate predictor
            fml = formula("police ~ violence"),
            surrogate_var = "violence",
            estimator_name = "logit",
            ## Design grid - error is scaled to the surrogate variable
            design_grid = expand.grid(
                n_label = c(100, 200, 300, 400, 500, 600),
                bias_scale=c(0.01, 0.05, 0.1, 0.5, 1.0, 2.0),
                sd_scale=seq(0.0, 1.0, by=0.1)
            ),
            make_error_func = .sim_continuous_error,
            out_prefix = "results/sim4/",
            skip_existing = FALSE
        )
    )
}

# Run one design of one simulation and save the result as an .rds file.
#
# Arguments:
#   simulation     -- simulation number (1, 2, 3 or 4).
#   slurm_array_id -- 1-based row of that simulation's design grid; also used,
#                     zero-padded to three digits, as the result file name.
#
# Side effects: prints the selected design row and writes
#   results/sim<simulation>/<slurm_array_id>.rds (the directory is created if
#   it does not exist yet).
#
# Returns NULL invisibly.
#
# Stops if the simulation number is invalid or if slurm_array_id is not a whole
# number within the design grid. Both checks run before any data is loaded.
#
# Only simulation 1 skips the run when its result file already exists; the
# other simulations recompute and overwrite, as they always have.
main <- function(simulation=1, slurm_array_id=1){
    spec <- .sim_spec(simulation)

    ## An id outside the grid would otherwise select a row of NAs
    n_designs <- nrow(spec$design_grid)
    if (length(slurm_array_id) != 1 || is.na(slurm_array_id) ||
        slurm_array_id %% 1 != 0 ||
        slurm_array_id < 1 || slurm_array_id > n_designs) {
        stop("slurm_array_id must be a whole number between 1 and ", n_designs,
             " for simulation ", simulation, call.=FALSE)
    }

    ## Define out file
    out_file <- paste0(spec$out_prefix, sprintf("%03d", slurm_array_id), ".rds")
    ## Check if exists; if so then exit:
    if (spec$skip_existing && file.exists(out_file)) {
        print(paste("File exists:", out_file))
        return(invisible(NULL))
    }

    src_data <- load_data(dev=DEV)

    ## Every simulation picks its design by array id (simulation 3 used to be
    ## pinned to row 1, so all of its array tasks ran the same design).
    design <- spec$design_grid[slurm_array_id,]
    print(design)
    error_func <- spec$make_error_func(design)

    ## Make sure the target directory exists before doing the expensive work
    dir.create(dirname(out_file), recursive=TRUE, showWarnings=FALSE)

    ## Run simulation
    sim_result <- repeat_simulation(
        src_data = src_data,
        fml = spec$fml,
        surrogate_var = spec$surrogate_var,
        n_label = design$n_label,
        estimator_name = spec$estimator_name,
        error_func = error_func,
        n_sims = N_SIMS_PER_DESIGN
    )

    ## Save results
    saveRDS(sim_result, file=out_file)
    invisible(NULL)
}


# Run main
## Read the two positional arguments once and fail early, with a usage hint,
## when either is missing or not numeric (as.numeric() would otherwise hand NA
## to main(), which fails later with an unrelated-looking error).
cli_args <- commandArgs(trailingOnly=TRUE)
if (length(cli_args) < 2) {
    stop("Usage: Rscript run_simulation.R <simulation> <slurm_array_id>", call.=FALSE)
}
## The coercion warning is suppressed because the NA case is reported just below
cli_values <- suppressWarnings(as.numeric(cli_args[1:2]))
if (anyNA(cli_values)) {
    stop("<simulation> and <slurm_array_id> must both be numeric; got: ",
         paste(cli_args[1:2], collapse=" "), call.=FALSE)
}
main(simulation=cli_values[1], slurm_array_id=cli_values[2])